from typing import Any, List, Optional, Tuple, Callable
import re
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BartForTextInfill, BartTokenizer
from MeaCap.models.clip_utils import CLIP


def summarizer(summarize_model, modality, captions, num_summaries, stop_words, verbose=False, model=None):
    """Ask an LLM for concepts shared by a group of captions and clean them up.

    The captions are rendered as a quoted bullet list, inserted into a prompt
    template, sent to the LLM through ``summarize_sentences``, and the
    returned explanations are post-processed with ``correct_explanation``.

    Params
    ------
    summarize_model
        Passed to ``summarize_sentences`` as its first argument (the model
        name used there to pick a backend).
    modality
        Not used by this function; kept so existing callers keep working.
    captions
        Iterable of caption strings; each becomes one ``* "<caption>"`` line.
    num_summaries: int
        Number of concepts to request. Values greater than 1 select the
        multi-concept prompt; anything else selects the single-concept prompt.
    stop_words
        Forwarded to ``correct_explanation``; words found in it are removed.
    verbose: bool
        When true, print the final prompt and let ``summarize_sentences``
        print each parsed LLM reply.
    model
        Optional model object forwarded to ``summarize_sentences``.

    Returns
    -------
    The value returned by ``correct_explanation`` for the LLM explanations.
    """
    bullet_list_captions = '* "' + '"\n* "'.join(captions) + '"'

    # NOTE(review): the two templates differ slightly in wording ("a insect"
    # vs "an insect"); they are left untouched because editing the prompt
    # changes what the LLM produces.
    if num_summaries > 1:
        prompt = """
The following are the result of captioning a group of images:

{captions}

I am a machine learning researcher seeking to elucidate the concepts of this group in order to better understand my data.

Come up with {num_summaries} distinct concepts that are likely to be true for this group. Please write a list of captions separated by bullet points ("*"). For example: 
* "a dog next to a horse"
* "a car in the rain"
* "low quality"
* "cars from a side view"
* "people in a intricate dress"
* "a joyful atmosphere"

Do not talk about the caption, e.g., "caption with one word" and do not list more than one concept. Also use singular form unless the concept naturally involves multiple objects.
The hypothesis should be a caption, so hypotheses like "more of ...", "presence of ...", "images with ..." are incorrect. Also do not enumerate possibilities within parentheses. Do not provide multiple options by using 'or' or '/' to maintain clarity. Here are examples of bad outputs and their corrections:
* INCORRECT: "various nature environment like lake, forest, and mountain" CORRECTED: "nature"
* INCORRECT: "a image caption of household object (e.g. bowl, vacuum, lamp)" CORRECTED: "a household object" 
* INCORRECT: "Presence of baby animal" CORRECTED: "a baby animal"
* INCORRECT: "Different types of vehicles including cars, trucks, boats, and RVs" CORRECTED: "a vehicle"
* INCORRECT: "Image caption involving interaction between humans and animals" CORRECTED: "interaction between humans and animals" 
* INCORRECT: "More realistic image" CORRECTED: "realistic image"
* INCORRECT: "Insect (cockroach, dragonfly, grasshopper)" CORRECTED: "a insect"
* INCORRECT: "newspaper or magazine" CORRECTED: "a print media"

Again, I want to identify the characteristics of this group. List properties that hold more often for the images in this group. Answer only with a list (separated by bullet points “*”). Your response: 
                """
    else:
        prompt = """
The following are the result of captioning a group of images:

{captions}

I am a machine learning researcher seeking to elucidate the concept of this group in order to better understand my data.

Come up with 1 distinct concept that is likely to be true for this group. Please write the concept as a single bullet point ("*"). For example: 
* "a dog next to a horse"
* "a car in the rain"
* "low quality"
* "cars from a side view"
* "people in an intricate dress"
* "a joyful atmosphere"

Do not talk about the caption, e.g., "caption with one word" and do not list more than one concept. Also use singular form unless the concept naturally involves multiple objects.
The hypothesis should be a caption, so hypotheses like "more of ...", "presence of ...", "images with ..." are incorrect. Also do not enumerate possibilities within parentheses. Do not provide multiple options by using 'or' or '/' to maintain clarity. Here are examples of bad outputs and their corrections:
* INCORRECT: "various nature environment like lake, forest, and mountain" CORRECTED: "nature"
* INCORRECT: "a image caption of household object (e.g. bowl, vacuum, lamp)" CORRECTED: "a household object"
* INCORRECT: "Presence of baby animal" CORRECTED: "a baby animal"
* INCORRECT: "Different types of vehicles including cars, trucks, boats, and RVs" CORRECTED: "a vehicle"
* INCORRECT: "Image caption involving interaction between humans and animals" CORRECTED: "interaction between humans and animals"
* INCORRECT: "More realistic image" CORRECTED: "realistic image"
* INCORRECT: "Insect (cockroach, dragonfly, grasshopper)" CORRECTED: "an insect"
* INCORRECT: "newspaper or magazine" CORRECTED: "a print media"

Again, I want to identify the characteristic of this group. List a property that holds more often for the images in this group. Answer only with a single bullet point (“*”). Your response:
        """

    # The single-concept template has no {num_summaries} placeholder;
    # str.format ignores the unused keyword argument.
    modality_prompt = prompt.format(captions=bullet_list_captions, num_summaries=num_summaries)
    if verbose:
        print(modality_prompt)

    # Forward the caller's verbosity instead of always printing LLM replies.
    expls_list = summarize_sentences(
        summarize_model, modality_prompt, num_summaries, verbose=verbose, model=model
    )

    return correct_explanation(expls_list, stop_words)

def _split_bullets(text: Optional[str]) -> List[str]:
    """Split a "*"-bulleted reply into stripped items with double quotes removed."""
    # Text before the first "*" is preamble and is discarded; a None reply
    # is treated as an empty string so it parses to an empty list.
    return [item.strip().replace('"', "") for item in (text or "").split("*")[1:]]


def summarize_sentences(
    model_name: str,
    prompt: str,
    num_summaries: int,
    verbose: bool = True,
    model: Optional[Any] = None,
) -> List[str]:
    """Query an LLM with ``prompt`` and parse its bullet-point reply.

    The request is repeated until the reply contains exactly
    ``num_summaries`` bullet items or the attempt budget is used up
    (5 attempts for "gpt" models, 10 for "Llama-3.1" models). If no attempt
    yields the requested count, the items parsed from the last attempt are
    returned, so the result may have a different length than requested.

    Params
    ------
    model_name: str
        Name used to select the backend. Names containing "gpt" are sent to
        the OpenAI chat completions API (a new ``OpenAI()`` client is created
        and ``model`` is ignored); names containing "Llama-3.1" use ``model``.
    prompt: str
        The full user prompt sent to the LLM.
    num_summaries: int
        The number of bullet items expected in the reply.
    verbose: bool
        When true, print the parsed items of every attempt.
    model: Optional[Any]
        For "Llama-3.1" names: a callable invoked as
        ``model(messages, max_length=8192)`` whose result is read as
        ``result[0]["generated_text"][-1]["content"]`` (presumably a
        transformers chat text-generation pipeline -- confirm with caller).

    Returns
    -------
    summaries: List[str]
        The parsed bullet items, stripped and with double quotes removed.

    Raises
    ------
    ValueError
        If ``model_name`` matches no supported backend, or a "Llama-3.1"
        name is given without a ``model``.
    """
    if "gpt" in model_name:
        from openai import OpenAI

        client = OpenAI()
        max_attempts = 5

        def ask() -> Optional[str]:
            completion = client.chat.completions.create(
                model=model_name,
                messages=[
                    {"role": "system", "content": ""},
                    {"role": "user", "content": prompt},
                ],
            )
            return completion.choices[0].message.content

    elif "Llama-3.1" in model_name:
        if model is None:
            raise ValueError(
                f"A loaded model must be passed via 'model' for {model_name!r}"
            )
        max_attempts = 10

        def ask() -> Optional[str]:
            # A fresh message list per attempt, as the reply is read from the
            # last entry of the returned conversation.
            messages = [{"role": "user", "content": prompt}]
            return model(messages, max_length=8192)[0]["generated_text"][-1]["content"]

    else:
        # Previously this fell through to an unbound-variable error at return.
        raise ValueError(f"Unsupported summarization model: {model_name!r}")

    split_expl: List[str] = []
    for _ in range(max_attempts):
        split_expl = _split_bullets(ask())
        if verbose:
            print(split_expl)
        if len(split_expl) == num_summaries:
            break

    return split_expl

def correct_explanation(expls: List[str], stop_words) -> List[str]:
    """Normalize explanations: lower-case, keep word tokens, drop stop words.

    Params
    ------
    expls: List[str]
        Explanation sentences to clean. Note that passing a single string
        would iterate over its characters rather than treat it as a sentence.
    stop_words
        Container tested with ``in``; every token found in it (for example
        articles) is removed. Tokens are lower-cased before the check.

    Returns
    -------
    List[str]
        One cleaned string per input sentence, with the remaining words
        joined by single spaces (possibly empty).
    """
    # Compile once; the same pattern is applied to every sentence.
    word_pattern = re.compile(r'\b\w+\b')
    return [
        " ".join(
            word
            for word in word_pattern.findall(sentence.lower())
            if word not in stop_words
        )
        for sentence in expls
    ]


def load_parser_model_and_tokenizer(parser_checkpoint: str, device: str):
    """
    Explanation:
    Builds the tokenizer and the seq2seq parser model from the same
    checkpoint, switches the model to evaluation mode and moves it to
    `device`. Returns the pair as (tokenizer, model).
    """
    tokenizer = AutoTokenizer.from_pretrained(parser_checkpoint)

    model = AutoModelForSeq2SeqLM.from_pretrained(parser_checkpoint)
    model.eval()
    model.to(device)

    return tokenizer, model

def load_lm_model(args):
    """
    Explanation:
    Returns (tokenizer, lm_model, stop_tokens_tensor, sub_tokens_tensor).
    Only the "MeaCap" caption model has a language-model checkpoint; for any
    other `args.caption_model` all four entries are None. Otherwise the
    BartTokenizer / BartForTextInfill pair is loaded, the model is moved to
    `args.device`, and two zero vectors of length `tokenizer.vocab_size`
    are created on that device.
    """
    checkpoint_dir = (
        "./MeaCap/checkpoints/CBART_one_billion"
        if args.caption_model == "MeaCap"
        else None
    )

    # Nothing to load for caption models without a checkpoint.
    if not checkpoint_dir:
        return None, None, None, None

    bart_tokenizer = BartTokenizer.from_pretrained(checkpoint_dir)
    bart_model = BartForTextInfill.from_pretrained(checkpoint_dir)
    bart_model = bart_model.to(args.device)

    # Both token tensors start out as all zeros.
    stop_tokens = torch.zeros(bart_tokenizer.vocab_size).to(args.device)
    sub_tokens = torch.zeros(bart_tokenizer.vocab_size).to(args.device)

    return bart_tokenizer, bart_model, stop_tokens, sub_tokens

def load_vl_model(device: str):
    """
    Explanation:
    Instantiates the project's CLIP wrapper with the
    "openai/clip-vit-base-patch32" checkpoint identifier and returns the
    result of moving it to `device`.
    """
    clip_wrapper = CLIP("openai/clip-vit-base-patch32")
    return clip_wrapper.to(device)
